{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/merveenoyan/smollm/blob/main/vision/finetuning/Smol_VLM_FT.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nc0g2NLpUSGr"
      },
      "source": [
        "# Fine-tune SmolVLM on Visual Question Answering using Consumer GPU with QLoRA\n",
        "\n",
        "In this notebook we will fine-tune SmolVLM on the VQAv2 dataset. You can also use this notebook to fine-tune Idefics3, since both models share the same model class and architecture.\n",
        "\n",
        "We will use a few techniques that let you fine-tune the model on an L4 GPU with a batch size of 4 using only around 16.4 GB of VRAM. We tested the notebook in that setup, but since we had access to an A100, the notebook was last run on an A100."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WIhA1lQ7j0kw"
      },
      "outputs": [],
      "source": [
        "!pip install -q accelerate datasets peft bitsandbytes tensorboard"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XyJaqZZ3uYYl"
      },
      "outputs": [],
      "source": [
        "!pip install -q flash-attn --no-build-isolation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wAeMA0heVBjT"
      },
      "source": [
        "We will push our model to the Hub, so we need to authenticate ourselves."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yKd5xtSGj7cm"
      },
      "outputs": [],
      "source": [
        "from huggingface_hub import notebook_login\n",
        "\n",
        "notebook_login()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WRq8ve-LVAzU"
      },
      "source": [
        "In this notebook we will not do full fine-tuning. Instead we use the QLoRA method, which attaches an adapter to a quantized version of the model and so saves memory. If you want to do full fine-tuning, set `USE_LORA` and `USE_QLORA` to False. If you want to do LoRA, set `USE_QLORA` to False and `USE_LORA` to True."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "referenced_widgets": [
            "23d3d175e6e642c7abc2bce09b73cf4d",
            "db6ca8f47f274464b135909c907c946a",
            "d05822c6293c424fbf9df6ec0f6b532b",
            "05582fca18f443d6965776a721e30e9f",
            "3d8974fd1ba9415c8070c1eab8ad75cb",
            "648257c1b1c24e25a26355bddf75aa41",
            "afa9a31c6b7f45e082ae07dea4a2600e",
            "92232af543a4446cac53e4fcf3f4b6e1",
            "a5f06e59634f4edf9f3d9409846a2b31",
            "7ddfa8718bc24882ba2b50a899656107",
            "5983728a1c1e43edb4d16bee6ad40171",
            "dff574197f1f4466abb0eb46d36b8378"
          ]
        },
        "id": "b9CDMq0duYYn",
        "outputId": "65a4a5fa-fe4d-4243-b2d7-405a8aa81c04"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "23d3d175e6e642c7abc2bce09b73cf4d",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "processor_config.json:   0%|          | 0.00/68.0 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "db6ca8f47f274464b135909c907c946a",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "chat_template.json:   0%|          | 0.00/434 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "d05822c6293c424fbf9df6ec0f6b532b",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "preprocessor_config.json:   0%|          | 0.00/486 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "05582fca18f443d6965776a721e30e9f",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer_config.json:   0%|          | 0.00/4.04k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "3d8974fd1ba9415c8070c1eab8ad75cb",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "vocab.json:   0%|          | 0.00/801k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "648257c1b1c24e25a26355bddf75aa41",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "merges.txt:   0%|          | 0.00/466k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "afa9a31c6b7f45e082ae07dea4a2600e",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer.json:   0%|          | 0.00/3.52M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "92232af543a4446cac53e4fcf3f4b6e1",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "added_tokens.json:   0%|          | 0.00/92.0 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "a5f06e59634f4edf9f3d9409846a2b31",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "special_tokens_map.json:   0%|          | 0.00/1.07k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Some kwargs in processor config are unused and will not have any effect: image_seq_len. \n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "7ddfa8718bc24882ba2b50a899656107",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "config.json:   0%|          | 0.00/7.08k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "5983728a1c1e43edb4d16bee6ad40171",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model.safetensors:   0%|          | 0.00/4.49G [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "dff574197f1f4466abb0eb46d36b8378",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "generation_config.json:   0%|          | 0.00/132 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "(10536960, 2256809840)\n"
          ]
        }
      ],
      "source": [
        "import torch\n",
        "from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model\n",
        "from transformers import AutoProcessor, BitsAndBytesConfig, Idefics3ForConditionalGeneration\n",
        "\n",
        "USE_LORA = False\n",
        "USE_QLORA = True\n",
        "SMOL = True\n",
        "\n",
        "model_id = \"HuggingFaceTB/SmolVLM-Base\" if SMOL else \"HuggingFaceM4/Idefics3-8B-Llama3\"\n",
        "\n",
        "processor = AutoProcessor.from_pretrained(\n",
        "    model_id\n",
        ")\n",
        "\n",
        "if USE_QLORA or USE_LORA:\n",
        "    lora_config = LoraConfig(\n",
        "        r=8,\n",
        "        lora_alpha=8,\n",
        "        lora_dropout=0.1,\n",
        "        target_modules=['down_proj','o_proj','k_proj','q_proj','gate_proj','up_proj','v_proj'],\n",
        "        use_dora=False if USE_QLORA else True,\n",
        "        init_lora_weights=\"gaussian\"\n",
        "    )\n",
        "    lora_config.inference_mode = False\n",
        "    if USE_QLORA:\n",
        "        bnb_config = BitsAndBytesConfig(\n",
        "            load_in_4bit=True,\n",
        "            bnb_4bit_use_double_quant=True,\n",
        "            bnb_4bit_quant_type=\"nf4\",\n",
        "            bnb_4bit_compute_dtype=torch.bfloat16\n",
        "        )\n",
        "\n",
        "    model = Idefics3ForConditionalGeneration.from_pretrained(\n",
        "        model_id,\n",
        "        quantization_config=bnb_config if USE_QLORA else None,\n",
        "        _attn_implementation=\"flash_attention_2\",\n",
        "        device_map=\"auto\"\n",
        "    )\n",
        "    model.add_adapter(lora_config)\n",
        "    model.enable_adapters()\n",
        "    model = prepare_model_for_kbit_training(model)\n",
        "    model = get_peft_model(model, lora_config)\n",
        "    print(model.get_nb_trainable_parameters())\n",
        "else:\n",
        "    model = Idefics3ForConditionalGeneration.from_pretrained(\n",
        "        model_id,\n",
        "        torch_dtype=torch.bfloat16,\n",
        "        _attn_implementation=\"flash_attention_2\",\n",
        "    ).to(DEVICE)\n",
        "\n",
        "    # if you'd like to only fine-tune LLM\n",
        "    for param in model.model.vision_model.parameters():\n",
        "        param.requires_grad = False"
      ]
    },
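    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough sanity check, the tuple printed by the cell above can be turned into a trainable-parameter fraction. This is a minimal sketch that hard-codes the two printed numbers, assuming they are the (trainable, total) parameter counts returned by PEFT's `get_nb_trainable_parameters()`. The variable names are ours, and the result should come out well under 1%."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Numbers copied from the printed output above (assumed order: trainable, total).\n",
        "trainable_params, total_params = 10536960, 2256809840\n",
        "\n",
        "# Share of parameters that receive gradient updates with the adapter attached.\n",
        "fraction = trainable_params / total_params\n",
        "print(f'{100 * fraction:.2f}% of parameters are trainable')"
      ]
    },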
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WIVhpp0EyZO2"
      },
      "source": [
        "As loaded, the model takes up 2.7 GB of GPU memory 💗"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LMTtg3dl3NX2"
      },
      "source": [
        "## Loading the dataset and Preprocessing"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pWHMWTSZ3Pyr"
      },
      "source": [
        "We will load a small portion of the VQAv2 dataset. We use only a small portion of the dataset for educational purposes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "POOqKqYRka5O"
      },
      "outputs": [],
      "source": [
        "from datasets import load_dataset\n",
        "ds = load_dataset('merve/vqav2-small', trust_remote_code=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Znf9vMo5rnSd"
      },
      "outputs": [],
      "source": [
        "split_ds = ds[\"validation\"].train_test_split(test_size=0.5)\n",
        "train_ds = split_ds[\"train\"]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "FIDioFlRuYYn",
        "outputId": "79b697a7-d245-4fdc-b0e8-d9ffa8627953"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "Dataset({\n",
              "    features: ['multiple_choice_answer', 'question', 'image'],\n",
              "    num_rows: 10717\n",
              "})"
            ]
          },
          "execution_count": 7,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "train_ds"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5nwMO3n0X7Hv"
      },
      "source": [
        "Let's write our data collating function. We will apply the prompt template so that each question and its answer appear together and the model can learn to answer. Then we pass the formatted prompts and the images to the processor, which processes both."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e0krVLZ-wNMl"
      },
      "outputs": [],
      "source": [
        "# Look up the id of the <image> special token so it can be masked out of the labels.\n",
        "image_token_id = processor.tokenizer.additional_special_tokens_ids[\n",
        "    processor.tokenizer.additional_special_tokens.index(\"<image>\")\n",
        "]\n",
        "\n",
        "def collate_fn(examples):\n",
        "    texts = []\n",
        "    images = []\n",
        "    for example in examples:\n",
        "        image = example[\"image\"]\n",
        "        if image.mode != 'RGB':\n",
        "            image = image.convert('RGB')\n",
        "        question = example[\"question\"]\n",
        "        answer = example[\"multiple_choice_answer\"]\n",
        "        # One user turn (instruction, image, question) followed by the assistant's answer.\n",
        "        messages = [\n",
        "            {\n",
        "                \"role\": \"user\",\n",
        "                \"content\": [\n",
        "                    {\"type\": \"text\", \"text\": \"Answer briefly.\"},\n",
        "                    {\"type\": \"image\"},\n",
        "                    {\"type\": \"text\", \"text\": question}\n",
        "                ]\n",
        "            },\n",
        "            {\n",
        "                \"role\": \"assistant\",\n",
        "                \"content\": [\n",
        "                    {\"type\": \"text\", \"text\": answer}\n",
        "                ]\n",
        "            }\n",
        "        ]\n",
        "        text = processor.apply_chat_template(messages, add_generation_prompt=False)\n",
        "        texts.append(text.strip())\n",
        "        images.append([image])\n",
        "\n",
        "    batch = processor(text=texts, images=images, return_tensors=\"pt\", padding=True)\n",
        "    # Use the input ids as labels, masking padding and image tokens with -100 so the loss ignores them.\n",
        "    labels = batch[\"input_ids\"].clone()\n",
        "    labels[labels == processor.tokenizer.pad_token_id] = -100\n",
        "    labels[labels == image_token_id] = -100\n",
        "    batch[\"labels\"] = labels\n",
        "\n",
        "    return batch"
      ]
    },
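    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see why padding and image tokens are set to `-100` in the collator, here is a minimal sketch in plain PyTorch rather than with the model itself. It assumes a toy vocabulary of 5 tokens and a hypothetical `PAD_ID = 0`, both chosen only for illustration. `-100` is the default `ignore_index` of `torch.nn.functional.cross_entropy`, so positions carrying that label do not contribute to the loss, and the two losses compared at the end should match."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn.functional as F\n",
        "\n",
        "PAD_ID = 0      # hypothetical pad token id, for illustration only\n",
        "VOCAB_SIZE = 5  # toy vocabulary size\n",
        "\n",
        "torch.manual_seed(0)\n",
        "logits = torch.randn(4, VOCAB_SIZE)            # one row of scores per token position\n",
        "labels = torch.tensor([3, 1, PAD_ID, PAD_ID])  # the last two positions are padding\n",
        "\n",
        "# Mask padding the same way collate_fn does.\n",
        "masked_labels = labels.clone()\n",
        "masked_labels[masked_labels == PAD_ID] = -100\n",
        "\n",
        "# cross_entropy skips labels equal to ignore_index (-100 by default).\n",
        "loss_masked = F.cross_entropy(logits, masked_labels)\n",
        "# Reference value: the loss computed over the two real tokens only.\n",
        "loss_real_only = F.cross_entropy(logits[:2], labels[:2])\n",
        "\n",
        "print(torch.allclose(loss_masked, loss_real_only))"
      ]
    },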
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kEYDjWpE3LD5"
      },
      "source": [
        "## Training"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QvAs896cdwg8"
      },
      "source": [
        "We can now initialize `TrainingArguments` and pass them to `Trainer`.\n",
        "\n",
        "Some notes:\n",
        "- If you use 8-bit QLoRA with the setup below, training uses around 16.4 GB of VRAM (beautiful, fits comfortably inside an L4, Colab free tier).\n",
        "- We use gradient accumulation to simulate a larger batch size.\n",
        "- We also save memory on intermediate activations by using gradient checkpointing.\n",
        "\n",
        "**Disclaimer:**\n",
        "The techniques here aren't a free lunch. The latter two add extra compute to training and thus slow it down a bit. For reference, on two A100s with a batch size of 16, training took 2 hrs 43 mins with 4 gradient accumulation steps; disabling gradient accumulation reduced this to 2 hrs 35 mins.\n",
        "If you want to speed things up, you can experiment with reducing to 4-bit precision and using a larger batch size. Note that 4-bit might result in the model learning less."
      ]
    },
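    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The effect of gradient accumulation on the step count can be worked out by hand. The sketch below uses the values from this notebook (10717 training rows, per-device batch size 4, 4 accumulation steps) and assumes a single GPU; `num_devices` and the other variable names are ours. Under that assumption the result should line up with the 670 optimizer steps shown in the training progress output."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "num_examples = 10717     # rows in train_ds, as printed earlier\n",
        "per_device_batch = 4     # per_device_train_batch_size\n",
        "grad_accum_steps = 4     # gradient_accumulation_steps\n",
        "num_devices = 1          # assumption: training on a single GPU\n",
        "\n",
        "# Number of examples that contribute to each optimizer update.\n",
        "effective_batch = per_device_batch * grad_accum_steps * num_devices\n",
        "# Optimizer updates needed to see every example once.\n",
        "steps_per_epoch = math.ceil(num_examples / effective_batch)\n",
        "\n",
        "print(effective_batch, steps_per_epoch)"
      ]
    },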
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QNE2yWAYrAhD"
      },
      "outputs": [],
      "source": [
        "from transformers import TrainingArguments, Trainer\n",
        "\n",
        "model_name = model_id.split(\"/\")[-1]\n",
        "\n",
        "training_args = TrainingArguments(\n",
        "    num_train_epochs=1,\n",
        "    per_device_train_batch_size=4,\n",
        "    gradient_accumulation_steps=4,\n",
        "    warmup_steps=50,\n",
        "    learning_rate=1e-4,\n",
        "    weight_decay=0.01,\n",
        "    logging_steps=25,\n",
        "    save_strategy=\"steps\",\n",
        "    save_steps=250,\n",
        "    save_total_limit=1,\n",
        "    optim=\"paged_adamw_8bit\", # for 8-bit, keep this, else adamw_hf\n",
        "    bf16=True, # underlying precision for 8bit\n",
        "    output_dir=f\"./{model_name}-vqav2\",\n",
        "    hub_model_id=f\"{model_name}-vqav2\",\n",
        "    report_to=\"tensorboard\",\n",
        "    remove_unused_columns=False,\n",
        "    gradient_checkpointing=True\n",
        ")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oBBSDpBhreJd",
        "outputId": "071ed677-1d9f-4f98-9d19-64834440c9c4"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
          ]
        }
      ],
      "source": [
        "trainer = Trainer(\n",
        "    model=model,\n",
        "    args=training_args,\n",
        "    data_collator=collate_fn,\n",
        "    train_dataset=train_ds,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_QOCpw_-uYYo",
        "outputId": "7abb6937-c072-435a-c3f5-6dbb5b0b9eea"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='9' max='670' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [  9/670 01:41 < 2:39:41, 0.07 it/s, Epoch 0.01/1]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "trainer.train()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0hN0QD9_uYYo"
      },
      "outputs": [],
      "source": [
        "trainer.push_to_hub()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "A100",
      "provenance": [],
      "name": "Smol_VLM_FT.ipynb",
      "include_colab_link": true
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}